function [pcntvar,Xfeatures_contributions_total,Xfeatures_contributions_on_PCs,XExpected_contribution_total,XExpected_contribution_on_PCs,V,U,Vs,Us,Xfeatures_ordering, M, newV] = ...
    SVD_multivariate_analysis(Y,X,covs_forY,covs_forX,number_components)
% SVD_multivariate_analysis  Cross-covariance SVD between Y and X, plus the
% contribution of every X feature to the extracted components.
%
% INPUTS
% Y, X: matrices of size n x Q and n x P, respectively (n number of subjects).
% covs_forY, covs_forX (optional): covariate matrices whose fit is removed
%   from Y and X, respectively. Empty (default) -> only the mean is removed.
% number_components (optional): number of Principal Components (PC), c.
%   Default 10.
%
% Every column of Y and X is shifted so that its minimum becomes eps,
% Box-Cox transformed (boxcox) towards an approximately normal distribution,
% and z-scored before the SVD.
%
% OUTPUTS
% pcntvar: 1 x c percents of explained cross-covariance for each pair of
%   components.
% Xfeatures_contributions_total: P x 1 total contribution of each variable
%   in X, i.e. the sum over components of Xfeatures_contributions_on_PCs.
% Xfeatures_contributions_on_PCs: c x P PC-specific contribution of each
%   variable in X: 100 * U(k,p)^2 / sum_p U(k,p)^2 * pcntvar(k).
% XExpected_contribution_total: scalar threshold for the total
%   contributions (the value every feature would get if all P features
%   contributed equally).
% XExpected_contribution_on_PCs: c x P thresholds for the PC-specific
%   contributions, same size as Xfeatures_contributions_on_PCs; every entry
%   of row k equals 100*pcntvar(k)/P (equal-contribution level for PC k).
% V and Vs: "spatial" components and subject loadings of Y.
% U and Us: "spatial" components and subject loadings of X.
% Xfeatures_ordering: ordering of X features from maximum to minimum total
%   contribution on PCs.
% M: c x c diagonal matrix of the square roots of the singular values,
%   normalized to sum to 1 (for further permutation analyses).
% newV: projections of the preprocessed Y onto the components, before the
%   sign/scale normalization applied to V (for further permutation analyses).
%--------------------------------------------------------------------------
% Based on Felix M. Carbonell's function SurfStatSVD. Modified by Yasser
% Iturria-Medina, to calculate contributions, on April 16th, 2020.
% And Adewale Q. added outputs (M and newV) for further permutation, December 2020

if nargin < 3, covs_forY = []; end
if nargin < 4, covs_forX = []; end
if nargin < 5 || isempty(number_components), number_components = 10; end

% Normalizing Y and X elements, to have an approximate normal distribution.
% boxcox needs strictly positive data, hence the shift to a minimum of eps.
for i = 1:size(Y,2), Y(:,i) = boxcox(Y(:,i) - min(Y(:,i)) + eps); end
for i = 1:size(X,2), X(:,i) = boxcox(X(:,i) - min(X(:,i)) + eps); end

% SVD analysis on the z-scored data
[~,V,U,Vs,Us,pcntvar, M, newV] = SurfStatSVD(zscore(Y), [], covs_forY, number_components, zscore(X), [], covs_forX);

P = size(U,2);   % number of X features
% Percent share of each X feature in the squared loadings of each component
% (c x P, every row sums to 100).
Ushare = 100*(U.^2)./repmat(sum(U.^2,2),1,P);
Xfeatures_contributions_on_PCs = Ushare.*repmat(pcntvar',1,P);
Xfeatures_contributions_total  = Ushare'*pcntvar';
% Equal-contribution reference levels (each feature holding 100/P percent
% of every component): per component, and summed over the components.
XExpected_contribution_on_PCs = repmat(100/P*pcntvar',1,P);
XExpected_contribution_total  = sum(100*1/P*pcntvar);

% Ordering with regards to the total X features contribution
[~,Xfeatures_ordering] = sort(Xfeatures_contributions_total,'descend');

% Optional diagnostic plots (uncomment to use):
% j = Xfeatures_ordering;
% comp_names = arrayfun(@(i) sprintf('Component %d',i), 1:length(pcntvar), 'UniformOutput', false);
% X_names = arrayfun(@(i) sprintf('X feature %d',i), 1:P, 'UniformOutput', false);
% figure; stem(pcntvar); title('Percent Explained Variance for Y and X');
% xticks(1:length(pcntvar)); xticklabels(comp_names); set(gca,'XTickLabelRotation',45);
% figure; imagesc(Xfeatures_contributions_on_PCs(:,j)); colormap Jet; colorbar; title('Contributions of X on Y via each PC');
% xticks(1:P); xticklabels(X_names(j)); set(gca,'XTickLabelRotation',45);
% figure; stem(Xfeatures_contributions_total(j)); colormap Jet; hold on; title('Contributions of X on Y across all PCs');
% plot((0:P)',XExpected_contribution_total*ones(P+1,1),'r');
% axis([0 P 0 max(Xfeatures_contributions_total)]);
% xticks(1:P); xticklabels(X_names(j)); set(gca,'XTickLabelRotation',45);

return;

function [YY,V,U,Vs,Us,pcntvar, M, newV] = SurfStatSVD( Y, maskY, ZY_remove, c, X, maskX, ZX_remove )
%SurfStatSVD  SVD of the cross-similarity between two data sets Y and X.
%
% Usage: [YY,V,U,Vs,Us,pcntvar,M,newV] = ...
%            SurfStatSVD( Y [,maskY [,ZY_remove [,c [,X [,maskX [,ZX_remove]]]]]] );
%
% Extension of SurfStatPCA. Each data set has the fit on its design
% (ZY_remove / ZX_remove) removed, every remaining column is scaled to unit
% sum of squares, and the n x n similarity matrices YY and XX are formed.
% The components come from the SVD of XX^(1/2)*YY*XX^(1/2). If X is empty
% or omitted, XX = YY.
%
% Y         = n x v matrix or n x v x k array of data, v=#vertices (features),
%             or memory map of same.
% maskY     = 1 x v logical vector, 1=inside, 0=outside; default ones(1,v).
% ZY_remove = model formula of type term, or scalar, or n x p design matrix
%             whose fit is removed from every column of Y. If all zeros,
%             nothing is removed. Default 1 (the mean over rows is removed).
% c         = number of components, default 4.
% X         = second data set, same formats as Y and the same n; default [].
% maskX     = logical mask for X, default all ones.
% ZX_remove = as ZY_remove, applied to X. Default 1.
%
% Outputs (all but YY are only computed when nargout > 1):
% YY      = n x n sum of Y*Y' over the residualized, unit-norm masked
%           columns of Y, divided by sum(maskY).
% V       = c x v (or c x v x k) components of Y. Each component is sign
%           flipped so its largest-magnitude value is positive and scaled to
%           unit root-mean-square over the mask (a tie between the largest
%           positive and negative magnitudes zeroes it); unmasked entries are 0.
% U       = the same for X ([] when X is empty).
% Vs      = n x c subject loadings for Y.
% Us      = n x c subject loadings for X ([] when X is empty).
% pcntvar = 1 x c percent of cross-covariance explained by each component.
% M       = c x c diagonal matrix of the square roots of the singular values,
%           normalized to sum to 1.
% newV    = (sum(maskY)*k) x c projections of the preprocessed Y onto the
%           components, before V's sign/scale normalization.

maxchunk = 2^25;   % max number of elements read from a memory map at once
if isnumeric(Y)
    [n,v,k] = size(Y);
    isnum = true;
    Ym = [];
    sz = [];
else
    Ym = Y;
    sz = Ym.Format{2};
    if length(sz)==2
        sz = sz([2 1]);
        k = 1;
    else
        sz = sz([3 1 2]);
        k = sz(3);
    end
    n = sz(1);
    v = sz(2);
    isnum = false;
end

if nargin<2 || isempty(maskY)
    maskY = ones(1,v)>0;
end
if nargin<3 || isempty(ZY_remove)
    ZY_remove = 1;
end
if nargin<4 || isempty(c)
    c = 4;
end
if nargin<5
    X = [];
end
if nargin<6
    maskX = [];
end
if nargin<7 || isempty(ZX_remove)
    ZX_remove = 1;
end

if isa(ZY_remove,'term')
    ZY_remove = double(ZY_remove);
end
if isa(ZX_remove,'term')
    ZX_remove = double(ZX_remove);
end
% A single design row (e.g. the scalar 1) is replicated for every subject.
if size(ZY_remove,1)==1
    ZY_remove = repmat(ZY_remove,n,1);
end
if size(ZX_remove,1)==1
    ZX_remove = repmat(ZX_remove,n,1);
end

% In-memory data is processed as one chunk, memory maps chunk by chunk.
[nc, chunk, n10] = svd_chunking(isnum, n, v, k, maxchunk);
if ~isnum
    fprintf(1,'%s',[num2str(round(v*n*k*4/2^20)) ' Mb to PCA, % remaining: 100 ']);
end
YY = zeros(n);
for ic=1:nc
    if ~isnum && rem(ic,n10)==0
        fprintf(1,'%s',[num2str(round(100*(1-ic/nc))) ' ']);
    end
    v1 = 1+(ic-1)*chunk;
    v2 = min(v1+chunk-1,v);
    % For in-memory data (single chunk) Y keeps the preprocessed matrix,
    % which is reused below to compute V.
    Y = svd_prepare_chunk(Y, Ym, isnum, sz, v1, v2, maskY, k, n, ZY_remove);
    YY = YY + Y*Y';
end
YY = YY/sum(maskY);
if ~isnum
    fprintf(1,'%s\n','Done');
end

if ~isempty(X)
    XX = SurfStatSVD(X,maskX,ZX_remove);
else
    XX = YY;
end

if nargout>1
    % Outputs of the second modality stay empty when there is no X.
    U = [];
    Us = [];

    % Symmetric square root of XX, then SVD of XX^(1/2)*YY*XX^(1/2).
    [UX,SX] = svd(XX);
    XXhalf = UX*diag(sqrt(diag(SX)))*UX';
    XY = XXhalf*YY*XXhalf;
    [Ahalf, L] = svd(XY);
    A = XXhalf*Ahalf;
    [ds,is] = sort(diag(L),'descend');
    pcntvar = ds(1:c)'/sum(ds)*100;
    tmpA = A'*YY*A;
    B = A*diag(diag(tmpA).^(-1/2));
    W = L.^(1/2);
    tmpB = YY*B*diag(diag(W).^(-1));
    B = B(:,is(1:c));
    tmpB = tmpB(:,is(1:c));

    %% For the first modality
    Vs = YY*B;
    maskYk = repmat(maskY,1,k);
    V = zeros(c,v*k);
    if isnum
        V(:,maskYk) = B'*Y;
    else
        fprintf(1,'%s',[num2str(round(v*n*k*4/2^20)) ' Mb to PCA, % remaining: 100 ']);
        for ic=1:nc
            if rem(ic,n10)==0
                fprintf(1,'%s',[num2str(round(100*(1-ic/nc))) ' ']);
            end
            v1 = 1+(ic-1)*chunk;
            v2 = min(v1+chunk-1,v);
            Yc = svd_prepare_chunk([], Ym, false, sz, v1, v2, maskY, k, n, ZY_remove);
            V(:,svd_chunk_columns(v, maskY, v1, v2, k)) = B'*Yc;
        end
        fprintf(1,'%s\n','Done');
    end
    % Unnormalized projections of the preprocessed Y (Y'*B), assembled from
    % V so that memory-mapped data contributes every chunk.
    newV = V(:,maskYk)';
    [Vn, s] = svd_normalize_rows(V(:,maskYk));
    V(:,maskYk) = Vn;
    Vs = Vs*diag(s);
    if k>1
        V = reshape(V,c,v,k);
    end

    %% For the second modality
    if ~isempty(X)
        Us = B*diag(diag(W(is(1:c),is(1:c))));
        if isnumeric(X)
            [~,vX,kX] = size(X);
            isnumX = true;
            Xm = [];
            szX = [];
        else
            Xm = X;
            szX = Xm.Format{2};
            if length(szX)==2
                szX = szX([2 1]);
                kX = 1;
            else
                szX = szX([3 1 2]);
                kX = szX(3);
            end
            vX = szX(2);
            isnumX = false;
        end
        if isempty(maskX)
            maskX = ones(1,vX)>0;
        end
        U = zeros(c,vX*kX);
        [ncX, chunkX, n10X] = svd_chunking(isnumX, n, vX, kX, maxchunk);
        if ~isnumX
            fprintf(1,'%s',[num2str(round(vX*n*kX*4/2^20)) ' Mb to PCA, % remaining: 100 ']);
        end
        for ic=1:ncX
            if ~isnumX && rem(ic,n10X)==0
                fprintf(1,'%s',[num2str(round(100*(1-ic/ncX))) ' ']);
            end
            v1 = 1+(ic-1)*chunkX;
            v2 = min(v1+chunkX-1,vX);
            Xc = svd_prepare_chunk(X, Xm, isnumX, szX, v1, v2, maskX, kX, n, ZX_remove);
            U(:,svd_chunk_columns(vX, maskX, v1, v2, kX)) = tmpB'*Xc;
        end
        % Only the memory-map path prints progress, so only it ends with Done.
        if ~isnumX
            fprintf(1,'%s\n','Done');
        end
        maskXk = repmat(maskX,1,kX);
        [Un, s] = svd_normalize_rows(U(:,maskXk));
        U(:,maskXk) = Un;
        Us = Us*diag(s);
        if kX>1
            U = reshape(U,c,vX,kX);
        end
    end

    sq = sqrt(ds(1:c));
    M = diag(sq/norm(sq,1));
end

function [nc, chunk, n10] = svd_chunking(isnum, n, v, k, maxchunk)
% Number of chunks, columns per chunk and progress-report period (in chunks).
if isnum
    nc = 1;
    chunk = v;
else
    nc = ceil(v*n*k/maxchunk);
    chunk = ceil(v/nc);
end
n10 = max(1, floor(nc/10));   % report progress about every 10% of the chunks

function D = svd_prepare_chunk(D, Dm, isnum, sz, v1, v2, mask, k, n, Z)
% Columns v1:v2 of the data restricted to the mask (n x sum(mask(v1:v2))*k),
% with the fit on Z removed and every column scaled to unit sum of squares
% (all-zero columns stay 0). In-memory data is taken from D; memory-mapped
% data is read from Dm and the incoming D is ignored.
if ~isnum
    if length(sz)==2
        D = double(Dm.Data(1).Data(v1:v2,:)');
    else
        D = double(permute(Dm.Data(1).Data(v1:v2,:,:),[3 1 2]));
    end
end
maskc = mask(v1:v2);
if k==1
    D = double(D(:,maskc));
else
    D = double(reshape(D(:,maskc,:),n,sum(maskc)*k));
end
if any(Z(:)~=0)
    D = D - Z*(pinv(Z)*D);
end
S = sum(D.^2,1);
Smhalf = (S>0)./sqrt(S+(S<=0));
D = D.*repmat(Smhalf,n,1);

function cols = svd_chunk_columns(v, mask, v1, v2, k)
% Logical selector over the v*k output columns for the masked columns of
% chunk v1:v2, repeated for each of the k slices.
cols = false(1,v);
cols(v1:v2) = mask(v1:v2);
if k>1
    cols = repmat(cols,1,k);
end

function [D, s] = svd_normalize_rows(D)
% Flip each row so its largest-magnitude value is positive and scale it to
% unit root-mean-square; s holds the applied signs (0 on ties/zero rows).
s = sign(abs(max(D,[],2))-abs(min(D,[],2)));
sv = sqrt(mean(D.^2,2));
D = diag(s./(sv+(sv<=0)).*(sv>0))*D;

